feat(input_pipeline): Add support for chunking long sequences instead of truncation #2354
base: main
Conversation
Thanks for the feature! And great unit tests! Just some minor comments.
Thanks for the great feedback! I've pushed the changes addressing all your points:
- Added a comment that `TokenizeAndChunk` removes all columns except the `text_column`
- Modified `_grain_tokenizer.py` with the latest changes
- Added a note that `use_truncation=False` is only available in grain's pretrain preprocessing pipeline
- Moved `feature_names`, `sequence_length`, `add_bos`, `add_eos`, and `tokenizer` to `TokenizerTransformBase` (sketched below)
- Consolidated initialization logic in the base class `__post_init__`
- Simplified `TokenizeAndTrim` and `TokenizeAndChunk` by removing duplicate parameters
- Added a common `_encode` method to eliminate code duplication
- Maintained backward compatibility and the specialized behavior of each class

to: @aireenmei

Looks like the GitHub Actions tests need to be triggered by a maintainer. Please take a look at the test failures. You can also run them locally.
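For context, here is a minimal sketch of the class layout that list describes, using the names mentioned above; the bodies are illustrative assumptions, not the PR's exact code, and some listed parameters (e.g. `add_bos`/`add_eos`) are omitted for brevity:

```python
# Minimal sketch of the consolidated tokenizer transforms (assumed shapes).
import dataclasses
from typing import Any

import grain.python as grain


@dataclasses.dataclass
class TokenizerTransformBase:
  """State shared by TokenizeAndTrim and TokenizeAndChunk."""
  feature_names: str | list[str]
  sequence_length: int | list[int]
  tokenizer: Any

  def __post_init__(self):
    # Consolidated initialization: normalize scalar arguments to lists once
    # here, rather than in each subclass.
    if isinstance(self.feature_names, str):
      self.feature_names = [self.feature_names]
    if isinstance(self.sequence_length, int):
      self.sequence_length = [self.sequence_length] * len(self.feature_names)

  def _encode(self, text: str) -> list[int]:
    # Common encode path; subclasses differ only in how they cut the ids.
    return list(self.tokenizer.encode(text))


@dataclasses.dataclass
class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
  """1:1 transform: truncate each feature to its maximum length."""

  def map(self, features: dict) -> dict:
    for name, max_len in zip(self.feature_names, self.sequence_length):
      features[name] = self._encode(features[name])[:max_len]
    return features
```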
LGTM, thanks! Pls make sure the tests are passing before merging
@aireenmei yeah, I think it all looks good for now. Thanks for the detailed reviews!
The GitHub runner uses `requirements_with_jax_ai_image.txt`. Could you change it in this PR and see if that fixes the test?
Hi @bzantium, can you rebase your PR and run the tests again? There were some changes to the testing infra last week that require a rebase for successful image builds in unit tests.
…os since they are used by the tokenizer itself, not the tokenizer transform
@aireenmei @Rohan-Bierneni @SurbhiJainUSC |
Thanks for the update! If grain==0.2.12 doesn't work, let me know if we need to request a new release from the grain team. |
@aireenmei thanks for the fast reply! I've checked that the current implementation works fine with grain==0.2.12 on tpu-v6e, but if the grain team could release a new version, that would be nice for this PR (the code could be made as neat as before) and for other upcoming features I want to add soon.
https://pypi.org/project/grain/0.2.13/
I've successfully tested this code on my TPUs... why did the test fail...?!
The corresponding TPU tests are passing for MaxText head. Could you rebase again? The MaxText GPU unit test seems to not pick up the new Grain version, but at head it has picked up the newest one (https://github.com/AI-Hypercomputer/maxtext/actions/runs/18940014406/job/54076248442). I manually added the pull ready label; not sure if this will make the tests run automatically once you rebase and repush the branch.
@aireenmei |
This PR introduces "chunking" as an alternative to "truncation" in the Grain input pipeline.
Previously, the `TokenizeAndTrim` operation (a `MapTransform`) would truncate any document longer than `max_target_length`, discarding all subsequent tokens. This change introduces a new `TokenizeAndChunk` operation (a `FlatMapTransform`) that splits a single long document into multiple training examples, each no longer than `max_target_length`. This new behavior is controlled by a new configuration flag, `use_truncation`.
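As a toy illustration of the difference (not the PR's code), with `max_target_length = 5`:

```python
# Toy example: a document of 7 tokens with max_target_length = 5.
tokens = [1, 2, 3, 4, 5, 6, 7]
max_len = 5

truncated = tokens[:max_len]  # [1, 2, 3, 4, 5] -- tokens 6 and 7 are discarded
chunked = [tokens[i:i + max_len] for i in range(0, len(tokens), max_len)]
# [[1, 2, 3, 4, 5], [6, 7]] -- every token is kept, as two training examples
```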
Why is this change being made?
The default truncation behavior is highly data-inefficient for corpora with many long documents (like C4). It wastes significant amounts of data, compute, and storage, and may bias the model by only ever training on the beginning of documents.
The problem being solved and any relevant context:
This PR solves the problem of data loss during tokenization for long sequences. By using a 1:N `FlatMapTransform`, we can map one long input document to a list of multiple valid training chunks, ensuring 100% of the tokenized data is used.
Why this is a good solution:
This solution is efficient and flexible. It utilizes the `FlatMapTransform` provided by Grain, which is designed for this 1:N mapping. It is also fully backwards-compatible, as the new chunking behavior is opt-in by setting `use_truncation = False` in the config. The default behavior remains truncation.
Some information about the specific implementation:
- `_grain_tokenizer.py`: A new `TokenizeAndChunk` class has been added. It inherits from `grain.experimental.FlatMapTransform` and implements the `flat_map` method to split a list of token IDs into multiple chunks.
- `_grain_data_processing.py`: The `pretrain_preprocessing_pipeline` function has been updated with a conditional check for `config.use_truncation` (see the sketch after this list):
  - If `True`, it uses the existing `dataset.map(TokenizeAndTrim(...))`.
  - If `False`, it uses `dataset.apply(TokenizeAndChunk(...))`.
- Note that the `dataset.apply()` method and support for `FlatMapTransform` are recent features in Grain. This PR requires a version of Grain installed directly from the main branch.
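A minimal sketch of that conditional; the import path, helper name, argument lists, and feature names are assumptions, not the PR's exact code:

```python
# Assumed import path; _grain_tokenizer is the module this PR modifies.
from MaxText.input_pipeline import _grain_tokenizer


def tokenize_stage(dataset, config, tokenizer):
  """Tokenize with either 1:1 trimming or 1:N chunking, per config.use_truncation."""
  columns = ["inputs", "targets"]  # assumed feature names
  if config.use_truncation:
    # Existing path: a MapTransform yields one (possibly truncated) example per document.
    return dataset.map(
        _grain_tokenizer.TokenizeAndTrim(columns, config.max_target_length, tokenizer))
  # New path: FlatMapTransforms go through dataset.apply(), allowing 1:N fan-out.
  return dataset.apply(
      _grain_tokenizer.TokenizeAndChunk(columns, config.max_target_length, tokenizer))
```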
Shortcomings of the solution and possible future improvements:
The `max_fan_out` attribute in `TokenizeAndChunk` is set with a class-level default (2048). If a document is exceptionally long and produces more chunks than this, it will error. This could be exposed as a configuration option in the future if needed.
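To make that bound concrete, a back-of-the-envelope sketch (not the PR's code):

```python
import math

MAX_FAN_OUT = 2048  # class-level default on TokenizeAndChunk


def num_chunks(num_tokens: int, max_target_length: int) -> int:
  # Number of examples a single document fans out into.
  return math.ceil(num_tokens / max_target_length)


# With max_target_length = 2048, a document would need more than
# 2048 * 2048 = 4,194,304 tokens before exceeding the default bound.
assert num_chunks(7, 5) == 2
assert num_chunks(2048 * 2048, 2048) == MAX_FAN_OUT
```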
Tests
This change is tested with a new, self-contained unit test file: `tests/tokenizer_transform_test.py`.
- Uses a `MockTokenizer` to provide known, deterministic tokenization ("a b c" -> [1, 2, 3]).
- Uses `grain.MapDataset.source` with a small, known dataset to test edge cases (short text, long text, and multi-chunk text).
- `test_tokenize_and_trim`: Verifies the original 1:1 truncation logic is correct.
- `test_tokenize_and_chunk`: Verifies the new 1:N chunking logic (e.g., an input with 7 tokens and `max_len=5` correctly produces two new examples with 5 and 2 tokens).
- `test_trim_and_pad_chaining`: Verifies that the output of `TokenizeAndTrim` can be correctly chained into a subsequent `PadToMaxLength` transform.
- `test_chunk_and_pad_chaining`: Verifies that all outputs from `TokenizeAndChunk` are correctly chained into `PadToMaxLength` (e.g., both the 5-token chunk and the 2-token chunk are correctly padded).

To reproduce, you can run the new test file directly:
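(The exact command was not preserved in the original; the invocation below is an assumption based on the file path and a standard pytest runner.)

```bash
python3 -m pytest tests/tokenizer_transform_test.py
```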
Fixes: #2344
Checklist
Before submitting this PR, please make sure (put X in square brackets):